# Alex Gazmararian
# agazmararian@gmail.com
# Survey weights calculation for visibility analysis

library(tidyverse)
library(here)
library(tidycensus)
library(survey)

# Register the Census API key with tidycensus when the CENSUS_API_KEY
# environment variable (typically set in .Renviron) is non-empty; otherwise
# only report that it is missing.
if (!nzchar(Sys.getenv("CENSUS_API_KEY"))) {
  message("Warning: CENSUS_API_KEY not found in .Renviron")
} else {
  tidycensus::census_api_key(Sys.getenv("CENSUS_API_KEY"))
  message("Census API key loaded from .Renviron")
}

# Progress note printed before the function definitions are evaluated
message("Loading survey weights calculation functions...")

# Rescale a covariate by its within-state standard deviation.
#
# Fits `variable ~ 1 | state` with fixest::feols (state fixed effects only),
# takes the standard deviation of that model's residuals, and divides the raw
# variable by it. The variable is scaled but not mean-centered.
#
# Args:
#   variable: Name of the column to rescale, given as a string.
#   data: Data frame containing `variable` and a `state` column.
#
# Returns:
#   `data` with a column named "<variable>_z" holding the rescaled values.
std_within_state <- function(variable, data) {
  fe_formula <- as.formula(paste0(variable, " ~ 1 | state"))
  within_sd <- sd(resid(fixest::feols(fe_formula, data = data)))
  z_name <- paste0(variable, "_z")
  mutate(data, !!z_name := !!sym(variable) / within_sd)
}

# Build ACS-based population targets for raking.
#
# Looks up the relevant ACS 5-year variables, downloads national ("us")
# estimates for them (both steps are cached as .rds files under
# here("data", "cache")), and collapses them into four target tables whose
# `Freq` column is the population share multiplied by `sample_size`.
#
# Args:
#   data_year: ACS year passed to tidycensus; also used in the cache file
#     names and in the income concept label.
#   sample_size: Single number the population shares are scaled to.
#
# Returns:
#   A list with data frames `joint` (gender, age_acs, edu_acs, Freq),
#   `race` (race_acs, Freq), `income` (income_acs, Freq) and
#   `region` (region, Freq).
#
# Raises:
#   An error if `sample_size` is not a single non-missing number, if any ACS
#   variable group cannot be found for `data_year`, or if the hard-coded
#   regional shares do not sum to 1.
prepare_acs_targets <- function(data_year = 2023, sample_size) {

  # Check the scaling factor up front rather than after the download
  if (missing(sample_size) || !is.numeric(sample_size) ||
      length(sample_size) != 1 || is.na(sample_size)) {
    stop("`sample_size` must be a single non-missing number")
  }

  message("Preparing ACS targets for survey weighting...")

  # Cache file for ACS variable definitions (to avoid API call)
  vars_cache_file <- here("data", "cache", paste0("acs_variables_", data_year, ".rds"))

  # Load ACS variables (from cache or API)
  if (file.exists(vars_cache_file)) {
    message("[OK] Using cached ACS variable definitions")
    acs.vars <- readRDS(vars_cache_file)
  } else {
    message("Loading ACS variable definitions from Census API...")
    acs.vars <- tidycensus::load_variables(data_year, "acs5")
    # Cache for future use
    dir.create(dirname(vars_cache_file), recursive = TRUE, showWarnings = FALSE)
    saveRDS(acs.vars, vars_cache_file)
    message("[OK] Cached ACS variable definitions to: ", vars_cache_file)
  }

  # Define variable groups. Only the most detailed cells are kept, identified
  # by the number of "!" characters in the label ("!!" separates levels).
  SexAgeEdu <- subset(acs.vars, concept == "Sex by Age by Educational Attainment for the Population 18 Years and Over")
  SexAgeEdu <- SexAgeEdu[stringr::str_count(SexAgeEdu$label, "!") == 8, ]

  # The income concept label embeds the data year, so derive it from
  # data_year instead of hard-coding 2023.
  income_concept <- paste0(
    "Household Income in the Past 12 Months (in ", data_year,
    " Inflation-Adjusted Dollars)"
  )
  income <- acs.vars[acs.vars$concept %in% income_concept, ]
  # Keep the income brackets and drop the "Total" row by label depth rather
  # than by row position.
  income <- income[stringr::str_count(income$label, "!") == 4, ]

  RaceAgeSex <- subset(acs.vars, concept %in% c("Sex by Age (White Alone)",
                                                "Sex by Age (Black or African American Alone)",
                                                "Sex by Age (American Indian and Alaska Native Alone)",
                                                "Sex by Age (Asian Alone)",
                                                "Sex by Age (Native Hawaiian and Other Pacific Islander Alone)",
                                                "Sex by Age (Some Other Race Alone)",
                                                "Sex by Age (Two or More Races)"))
  RaceAgeSex <- RaceAgeSex[stringr::str_count(RaceAgeSex$label, "!") == 6, ]

  # Concepts are matched on their exact text; stop with a clear message when
  # nothing matches instead of producing empty targets.
  group_sizes <- c(
    "sex by age by education" = nrow(SexAgeEdu),
    "household income" = nrow(income),
    "sex by age by race" = nrow(RaceAgeSex)
  )
  if (any(group_sizes == 0)) {
    stop(
      "No ACS variables found for: ",
      paste(names(group_sizes)[group_sizes == 0], collapse = ", "),
      " (data_year = ", data_year, "). ",
      "Concept labels may differ for this year."
    )
  }

  getvars <- c(SexAgeEdu$name, income$name, RaceAgeSex$name, "B01003_001")

  # Cache the raw ACS download
  cache_file <- here("data", "cache", paste0("acs_weights_raw_", data_year, ".rds"))

  if (file.exists(cache_file)) {
    message("[OK] Using cached ACS weights data from data/cache/")
    acs_long <- readRDS(cache_file)
  } else {
    # Get ACS data from Census API
    message("Downloading ACS data from Census API...")
    # get_acs() selects the 5-year file through `survey`; `sumfile` is a
    # get_decennial() argument.
    acs_long <- tidycensus::get_acs("us", variables = getvars, year = data_year, survey = "acs5")

    # Cache the raw download
    dir.create(dirname(cache_file), recursive = TRUE, showWarnings = FALSE)
    saveRDS(acs_long, cache_file)
    message("[OK] Cached ACS weights data to: ", cache_file)
  }

  # Process Sex-Age-Education targets
  acs.SexAgeEdu <- subset(acs_long, variable %in% SexAgeEdu$name)
  acs.SexAgeEdu <- merge(acs.SexAgeEdu, SexAgeEdu, by.x = "variable", by.y = "name")

  joint <- acs.SexAgeEdu %>%
    tidyr::separate(col = label, into = c("estimate2", "total", "gender", "age_acs", "edu"), sep = "!!") %>%
    subset(select = -c(concept, estimate2, total, geography, NAME, GEOID)) %>%
    dplyr::mutate(
      # Collapse education to an indicator: 1 = bachelor's degree or higher
      edu_acs = dplyr::case_when(
        edu %in% c("Less than 9th grade", "9th to 12th grade, no diploma", "High school graduate (includes equivalency)", "Associate's degree", "Some college, no degree") ~ 0,
        edu %in% c("Bachelor's degree", "Graduate or professional degree") ~ 1,
        TRUE ~ NA_real_
      )
    ) %>%
    dplyr::group_by(gender, age_acs, edu_acs) %>%
    dplyr::summarise(estimate = sum(estimate), .groups = "drop")

  joint$Freq <- joint$estimate / sum(joint$estimate)
  joint$estimate <- NULL
  joint$Freq <- joint$Freq * sample_size
  # Non-leaf label levels carry a trailing colon; strip it
  joint$gender <- gsub(":", "", joint$gender)
  joint$age_acs <- gsub(":", "", joint$age_acs)

  # Process Race-Age-Sex targets
  acs.RaceAgeSex <- subset(acs_long, variable %in% RaceAgeSex$name)
  acs.RaceAgeSex <- merge(acs.RaceAgeSex, RaceAgeSex, by.x = "variable", by.y = "name")

  race.dist <- acs.RaceAgeSex %>%
    tidyr::separate(col = label, into = c("estimate2", "total", "gender", "age_acs"), sep = "!!") %>%
    dplyr::select(-c(variable, GEOID, NAME, moe, estimate2, total, geography))

  # Clean race category names from ACS concept labels
  race.dist$race_acs <- gsub("Sex by Age \\(", "", race.dist$concept)
  race.dist$race_acs <- gsub("\\)", "", race.dist$race_acs)
  race.dist$concept <- NULL

  race.dist <- race.dist %>%
    dplyr::mutate(
      # Only the age brackets listed here are recoded; every other bracket
      # becomes NA and is dropped below, so race shares use these ages only.
      age_acs = dplyr::case_when(
        age_acs %in% c("18 and 19 years", "20 to 24 years") ~ "18 to 24 years",
        age_acs %in% c("25 to 29 years", "30 to 34 years") ~ "25 to 34 years",
        age_acs %in% c("35 to 44 years") ~ "35 to 44 years",
        age_acs %in% c("45 to 54 years", "55 to 64 years") ~ "45 to 64 years",
        age_acs %in% c("65 to 74 years", "75 to 84 years", "85 years and over") ~ "65 years and over",
        TRUE ~ NA_character_
      )
    ) %>%
    subset(!is.na(age_acs)) %>%
    mutate(
      race_acs = case_when(
        race_acs %in% c("Native Hawaiian and Other Pacific Islander Alone",
                        "Some Other Race Alone", "Two or More Races",
                        "American Indian and Alaska Native Alone") ~ "Other",
        race_acs == "Black or African American Alone" ~ "Black or African American Alone",
        race_acs == "Asian Alone" ~ "Asian Alone",
        race_acs == "White Alone" ~ "White Alone",
        TRUE ~ "Other"
      )
    ) %>%
    dplyr::group_by(race_acs) %>%
    dplyr::summarise(estimate = sum(estimate), .groups = "drop")

  race.dist$Freq <- race.dist$estimate / sum(race.dist$estimate)
  race.dist$Freq <- race.dist$Freq * sample_size
  race.dist$estimate <- NULL

  # Process Income targets
  acs.income <- subset(acs_long, variable %in% income$name)
  acs.income <- merge(acs.income, income, by.x = "variable", by.y = "name")
  income.dist <- acs.income %>%
    tidyr::separate(col = label, into = c("estimate2", "total", "income"), sep = "!!") %>%
    subset(select = -c(concept, estimate2, total, geography, NAME, GEOID)) %>%
    dplyr::mutate(income_acs = income) %>%
    # Collapse the ACS income brackets into five groups labelled Q1-Q5
    mutate(income_acs = case_when(
      income_acs %in% c("Less than $10,000", "$10,000 to $14,999", "$15,000 to $19,999",
        "$20,000 to $24,999", "$25,000 to $29,999") ~ "Q1",
      income_acs %in% c("$30,000 to $34,999", "$35,000 to $39,999", "$40,000 to $44,999", "$45,000 to $49,999",
        "$50,000 to $59,999") ~ "Q2",
      income_acs %in% c("$60,000 to $74,999", "$75,000 to $99,999") ~ "Q3",
      income_acs %in% c("$100,000 to $124,999", "$125,000 to $149,999") ~ "Q4",
      income_acs %in% c("$150,000 to $199,999", "$200,000 or more") ~ "Q5",
      TRUE ~ NA_character_
    )) %>%
    dplyr::group_by(income_acs) %>%
    dplyr::summarise(estimate = sum(estimate), .groups = "drop")

  income.dist$Freq <- income.dist$estimate / sum(income.dist$estimate)
  income.dist$estimate <- NULL
  income.dist$Freq <- income.dist$Freq * sample_size

  # Regional targets (hard-coded shares)
  midwest_freq <- 0.205
  northeast_freq <- 0.17
  south_freq <- 0.39
  west_freq <- 0.235

  # Check totals with a tolerance; exact `!= 1` on doubles can fail spuriously
  region_total <- midwest_freq + northeast_freq + south_freq + west_freq
  if (!isTRUE(all.equal(region_total, 1))) {
    stop("Region totals do not sum to 1")
  }

  region <- data.frame(
    "region" = c("Midwest", "Northeast", "South", "West"),
    "Freq" = c(midwest_freq, northeast_freq, south_freq, west_freq)
  )
  region$Freq <- region$Freq * sample_size

  return(list(
    joint = joint,
    race = race.dist,
    income = income.dist,
    region = region
  ))
}

# Main function to calculate survey weights using raking.
#
# Rakes an equal-weight design to four margins (gender x age x education,
# race, income, region) and then trims the resulting weights.
#
# Args:
#   data: Data frame with columns gender, age_acs, edu_acs, race_acs,
#     income_acs and region, none containing missing values.
#   joint_targets, race_targets, income_targets, region_targets: Target
#     tables with the margin's key column(s) and a `Freq` column. Any target
#     left NULL is generated with prepare_acs_targets(); targets that are
#     supplied are always used as given.
#   weight_bounds: Length-2 numeric vector c(lower, upper) passed to
#     trimWeights().
#   verbose: If TRUE, print raking progress, a weight summary and the maximum
#     deviation of each raked margin from its target.
#   data_year: ACS year used when targets have to be generated.
#
# Returns:
#   A list with `weights` (raked), `weights_trimmed`, `design` (raked design)
#   and `design_trimmed`.
calculate_survey_weights <- function(data,
                                     joint_targets = NULL,
                                     race_targets = NULL,
                                     income_targets = NULL,
                                     region_targets = NULL,
                                     weight_bounds = c(0.3, 3),
                                     verbose = TRUE,
                                     data_year = 2023) {

  n_obs <- nrow(data)

  # Validate required variables exist (before any target download)
  required_vars <- c("gender", "age_acs", "edu_acs", "race_acs", "income_acs", "region")
  missing_vars <- setdiff(required_vars, names(data))
  if (length(missing_vars) > 0) {
    stop("Missing required variables: ", paste(missing_vars, collapse = ", "))
  }

  # Report NAs in the raking variables by name up front, rather than leaving
  # them to surface as a less specific failure inside rake().
  has_na <- vapply(required_vars, function(v) anyNA(data[[v]]), logical(1))
  if (any(has_na)) {
    stop("Missing values in raking variables: ",
         paste(required_vars[has_na], collapse = ", "))
  }

  if (!is.numeric(weight_bounds) || length(weight_bounds) != 2 ||
      anyNA(weight_bounds) || weight_bounds[1] >= weight_bounds[2]) {
    stop("`weight_bounds` must be c(lower, upper) with lower < upper")
  }

  # Generate ACS targets only for the margins the caller did not supply, so
  # user-provided targets are never overwritten.
  if (is.null(joint_targets) || is.null(race_targets) ||
      is.null(income_targets) || is.null(region_targets)) {
    targets <- prepare_acs_targets(data_year = data_year, sample_size = n_obs)
    if (is.null(joint_targets)) joint_targets <- targets$joint
    if (is.null(race_targets)) race_targets <- targets$race
    if (is.null(income_targets)) income_targets <- targets$income
    if (is.null(region_targets)) region_targets <- targets$region
  }

  # Create survey design object. No weights are supplied, so svydesign()
  # assumes equal probability and warns about it; that is the intended
  # starting point for raking, hence the suppressed warning.
  design_unweight <- suppressWarnings(svydesign(ids = ~1, data = data))

  # Scale population targets so each margin sums to the sample size
  rescale_to_n <- function(targets) {
    targets$Freq <- targets$Freq * n_obs / sum(targets$Freq)
    targets
  }
  joint_scaled <- rescale_to_n(joint_targets)
  race_scaled <- rescale_to_n(race_targets)
  income_scaled <- rescale_to_n(income_targets)
  region_scaled <- rescale_to_n(region_targets)

  if (verbose) {
    message("Calculating weights for ", n_obs, " observations...")
  }

  # Apply raking
  design_weighted <- rake(
    design = design_unweight,
    sample.margins = list(
      ~gender + age_acs + edu_acs,
      ~race_acs,
      ~income_acs,
      ~region
    ),
    population.margins = list(
      joint_scaled,
      race_scaled,
      income_scaled,
      region_scaled
    ),
    control = list(maxit = 50, epsilon = 1e-7, verbose = verbose)
  )

  # Get base weights
  base_weights <- weights(design_weighted)

  # Apply weight trimming
  design_trimmed <- trimWeights(design_weighted,
                                lower = weight_bounds[1],
                                upper = weight_bounds[2],
                                strict = TRUE)
  trimmed_weights <- weights(design_trimmed)

  if (verbose) {
    message("Weight summary:")
    message("  Base weights - Min: ", round(min(base_weights), 3),
            ", Max: ", round(max(base_weights), 3),
            ", Mean: ", round(mean(base_weights), 3))
    message("  Trimmed weights - Min: ", round(min(trimmed_weights), 3),
            ", Max: ", round(max(trimmed_weights), 3),
            ", Mean: ", round(mean(trimmed_weights), 3))

    # Largest absolute gap between a raked margin and its target. Cells are
    # matched on their key columns: a svytable's cell order (first variable
    # varies fastest) generally differs from the row order of the target
    # table, so a positional subtraction would compare the wrong cells.
    max_margin_deviation <- function(margin, targets) {
      keys <- all.vars(margin)
      observed <- as.data.frame(svytable(margin, design_weighted),
                                stringsAsFactors = FALSE)
      expected <- as.data.frame(targets)
      # Compare keys as text so numeric target codes match table labels
      observed[keys] <- lapply(observed[keys], as.character)
      expected[keys] <- lapply(expected[keys], as.character)
      cells <- merge(observed, expected, by = keys,
                     suffixes = c(".obs", ".target"))
      if (nrow(cells) == 0) {
        return(NA_real_)
      }
      max(abs(cells$Freq.obs - cells$Freq.target))
    }

    # Check marginal deviations (of the raked, untrimmed design)
    edu_dev <- max_margin_deviation(~gender + age_acs + edu_acs, joint_scaled)
    race_dev <- max_margin_deviation(~race_acs, race_scaled)
    income_dev <- max_margin_deviation(~income_acs, income_scaled)
    region_dev <- max_margin_deviation(~region, region_scaled)

    message("  Marginal deviations - Education: ", round(edu_dev, 2),
            ", Race: ", round(race_dev, 2),
            ", Income: ", round(income_dev, 2),
            ", Region: ", round(region_dev, 2))
  }

  # Return list with weights and design objects
  return(list(
    weights = base_weights,
    weights_trimmed = trimmed_weights,
    design = design_weighted,
    design_trimmed = design_trimmed
  ))
}

# Prepare survey data for weighting.
#
# Adds the columns used for raking: `region` (from `state`), `age_acs`
# (from `age`), `race_acs` (from the black/white/asian/otherrace indicators),
# `edu_acs` (copy of `college`), `income_acs` (copy of `income5`) and `gender`
# (from `female`). `age` is also converted to integer.
#
# Args:
#   data: Data frame with columns state, age, black, white, asian, otherrace,
#     college, income5 and female. `state` is matched against full state
#     names (plus "District of Columbia").
#
# Returns:
#   `data` with the columns above added. Rows whose state or age cannot be
#   classified get NA in `region` / `age_acs`; a warning reports how many.
prepare_survey_for_weighting <- function(data) {

  message("Preparing survey data for weighting...")

  # Add region classification
  data <- data %>%
    mutate(
      region = case_when(
        state %in% c("Connecticut", "Maine", "Massachusetts", "New Hampshire", "Rhode Island", "Vermont", "New Jersey", "New York", "Pennsylvania") ~ "Northeast",
        state %in% c("Illinois", "Indiana", "Michigan", "Ohio", "Wisconsin", "Iowa", "Kansas", "Minnesota", "Missouri", "Nebraska", "North Dakota", "South Dakota") ~ "Midwest",
        state %in% c("Delaware", "District of Columbia", "Florida", "Georgia", "Maryland", "North Carolina", "South Carolina", "Virginia", "West Virginia", "Alabama", "Kentucky", "Mississippi", "Tennessee", "Arkansas", "Louisiana", "Oklahoma", "Texas") ~ "South",
        state %in% c("Arizona", "Colorado", "Idaho", "Montana", "Nevada", "New Mexico", "Utah", "Wyoming", "Alaska", "California", "Hawaii", "Oregon", "Washington") ~ "West",
        # Anything else (other spellings, territories, NA) is left missing
        TRUE ~ NA_character_
      )
    )

  # Add ACS-compatible age categories
  data <- data %>%
    mutate(
      age = as.integer(age),
      age_acs = case_when(
        age >= 18 & age <= 24 ~ "18 to 24 years",
        age >= 25 & age <= 34 ~ "25 to 34 years",
        age >= 35 & age <= 44 ~ "35 to 44 years",
        age >= 45 & age <= 64 ~ "45 to 64 years",
        age >= 65 ~ "65 years and over",
        # Typed NA keeps the column character regardless of dplyr version
        TRUE ~ NA_character_
      ),
      # A single-race label requires exactly one indicator set; every other
      # combination (including missing indicators) falls through to "Other"
      race_acs = case_when(
        black == 1 & white == 0 & asian == 0 & otherrace == 0 ~ "Black or African American Alone",
        white == 1 & black == 0 & asian == 0 & otherrace == 0 ~ "White Alone",
        asian == 1 & black == 0 & white == 0 & otherrace == 0 ~ "Asian Alone",
        TRUE ~ "Other"
      ),
      edu_acs = college,
      income_acs = income5,
      gender = ifelse(female == 1, "Female", "Male")
    )

  # Surface rows that could not be classified instead of leaving silent NAs
  unclassified <- c(
    region = sum(is.na(data$region)),
    age_acs = sum(is.na(data$age_acs))
  )
  unclassified <- unclassified[unclassified > 0]
  if (length(unclassified) > 0) {
    warning(
      "Rows left unclassified (NA): ",
      paste0(names(unclassified), " = ", unclassified, collapse = ", "),
      call. = FALSE
    )
  }

  return(data)
}

# Status message emitted after the definitions above have been evaluated
message("Survey weights functions loaded successfully!")
